(*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

(** Standard [Char], extended with ASCII case predicates. *)
module Char = struct
  include Char

  (** [is_lowercase c] holds when [c] is an ASCII lowercase letter or an underscore. *)
  let is_lowercase c = equal c '_' || ('a' <= c && c <= 'z')

  (** [is_uppercase c] holds when [c] is an ASCII uppercase letter. *)
  let is_uppercase c = 'A' <= c && c <= 'Z'
end

(** [CCStringLabels], extended with a few splitting and slicing helpers. *)
module String = struct
  include CCStringLabels

  (** [lsplit2 str ~on] cuts [str] at the first occurrence of the character [on], returning the
      text before and after it (both excluding [on]), or [None] if [on] does not occur in [str]. *)
  let lsplit2 str ~on =
    match index_opt str on with
    | None ->
        None
    | Some cut ->
        let before = sub str ~pos:0 ~len:cut in
        let after = sub str ~pos:(cut + 1) ~len:(length str - cut - 1) in
        Some (before, after)


  (** [subo ?pos ?len str] is [sub str ~pos ~len], where [pos] defaults to [0] and [len] defaults
      to the number of characters from [pos] to the end of [str]. *)
  let subo ?(pos = 0) ?len str =
    let len = Option.value len ~default:(length str - pos) in
    sub str ~pos ~len
end

(** Maps with string keys. *)
module Map = Map.Make (String)

(** Debug trace logging *)

(** Format strings that print to a [Format.formatter] and whose final result has type ['b]. *)
type ('a, 'b) fmt = ('a, Format.formatter, unit, 'b) format4

(** Type of printf-like functions: given a format string whose final result is [unit], they have
    the type ['a] of the function that consumes the format's arguments. *)
type 'a printf = ('a, unit) fmt -> 'a

(** A printf-like function wrapped in a record with an explicitly polymorphic field, so that it
    remains usable at any format type when passed as an argument. *)
type pf = {pf: 'a. 'a printf}

(** Formatter used for trace output: the standard error formatter. *)
let fs = Format.err_formatter

(** [flush ()] prints a newline on [fs] and flushes it. *)
let flush () = Format.pp_print_newline fs ()

(** Trace settings for one module: [trace_mod] is the module-wide setting, if one was given, and
    [trace_funs] holds the settings given for individual functions, keyed by function name. *)
type trace_mod_funs = {trace_mod: bool option; trace_funs: bool Map.t}

(** Per-module trace settings, keyed by module name. *)
type trace_mods_funs = trace_mod_funs Map.t

(** Tracing configuration: [trace_all] is the setting used when no module- or function-specific
    setting applies, [trace_mods_funs] holds the specific settings, and [colors] selects whether
    styled output emits escape sequences. *)
type config = {trace_all: bool; trace_mods_funs: trace_mods_funs; colors: bool}

(** Configuration with all tracing and colors disabled. *)
let none = {trace_all= false; trace_mods_funs= Map.empty; colors= false}

(** Configuration with tracing enabled by default and no specific settings. *)
let all = {none with trace_all= true}

(** The current configuration; initially [none]. *)
let config = ref none

(** Raised by [parse_exn] on a malformed trace specification; carries a description of the
    problem. *)
exception Parse_failure of string

(** [parse_exn s] parses the trace specification [s] into a configuration.

    The specification ["*"] yields [all]. Any other specification is a sequence of parts, each
    starting at a [+] or [-] sign (the sign of the first part may be omitted). A part names either
    a module ([Module]) or a function of a module ([Module.function]); it is enabled unless its
    sign is [-]. Parts are processed left to right: a function part adds to the settings of its
    module, while a module part replaces all earlier settings of that module, including those of
    its functions. The result has [trace_all] and [colors] set to [false].

    @raise Parse_failure
      if a part has no module name, if a module name does not start with an uppercase letter, or
      if a function name is empty or does not start with a lowercase letter or underscore. In
      particular the empty specification is rejected. *)
let parse_exn s =
  if String.equal s "*" then all
  else
    let default = Map.empty in
    let len = String.length s in
    (* Index of the first sign character at or after position [i], if any. Positions past the end
       of [s] are answered here because [String.index_from_opt] raises [Invalid_argument] for
       them, which [split] would otherwise trigger on the empty specification. *)
    let next_sign i =
      if i > len then None
      else
        match (String.index_from_opt s i '+', String.index_from_opt s i '-') with
        | None, o | o, None ->
            o
        | Some m, Some n ->
            Some (min m n)
    in
    (* Cut [s] just before every sign character found strictly after position [i], so that each
       part keeps its own leading sign. *)
    let rec split rev_parts i =
      match next_sign (i + 1) with
      | Some j ->
          split (String.sub s ~pos:i ~len:(j - i) :: rev_parts) j
      | None ->
          List.rev (String.subo s ~pos:i :: rev_parts)
    in
    (* Decompose one part into its module name, optional function name, and sign. *)
    let parse_part part =
      let sign, rest =
        (* an empty part has no sign character to inspect; it is rejected just below *)
        if String.is_empty part then (true, part)
        else
          match part.[0] with
          | '-' ->
              (false, String.subo part ~pos:1)
          | '+' ->
              (true, String.subo part ~pos:1)
          | _ ->
              (true, part)
      in
      if String.is_empty rest then raise (Parse_failure ("missing module name after: " ^ part)) ;
      if not (Char.is_uppercase rest.[0]) then
        raise (Parse_failure ("module name must be capitalized: " ^ rest)) ;
      match String.lsplit2 rest ~on:'.' with
      | Some (_, "") ->
          (* e.g. "Module." — report it instead of indexing into the empty function name *)
          raise (Parse_failure ("missing function name after: " ^ rest))
      | Some (mod_name, fun_name) ->
          if not (Char.is_lowercase fun_name.[0]) then
            raise (Parse_failure ("function name must not be capitalized: " ^ fun_name)) ;
          (mod_name, Some fun_name, sign)
      | None ->
          (rest, None, sign)
    in
    (* Record the setting described by [part] in the per-module map [m]. *)
    let add_part m part =
      match parse_part part with
      | mod_name, Some fun_name, enabled ->
          let {trace_mod; trace_funs} =
            match Map.find_opt mod_name m with
            | Some mod_funs ->
                mod_funs
            | None ->
                {trace_mod= None; trace_funs= default}
          in
          Map.add mod_name {trace_mod; trace_funs= Map.add fun_name enabled trace_funs} m
      | mod_name, None, enabled ->
          Map.add mod_name {trace_mod= Some enabled; trace_funs= default} m
    in
    let trace_mods_funs = List.fold_left add_part default (split [] 0) in
    {none with trace_mods_funs}


(** [pp_styled style fmt fs] prints the format [fmt], applied to the arguments that follow, to
    [fs] inside a box with indentation 2. If colors are enabled in the current [config], the
    output is preceded by the escape sequence selected by [style] ([`Bold], [`Cyan] or
    [`Magenta]) and followed by the reset sequence; otherwise [style] is ignored. *)
let pp_styled style fmt fs =
  Format.pp_open_box fs 2 ;
  if not !config.colors then Format.kfprintf (fun fs -> Format.pp_close_box fs ()) fs fmt
  else
    let start_style =
      match style with `Bold -> "\027[1m" | `Cyan -> "\027[36m" | `Magenta -> "\027[95m"
    in
    (* Escape sequences are printed with [pp_print_as ... 0] so that the whole sequence counts as
       zero columns. Writing them inline after ["@<0>"] in a format string only makes the single
       next character zero-width, so the rest of the sequence would be counted as visible text
       and skew line breaking. *)
    Format.pp_print_as fs 0 start_style ;
    Format.kfprintf
      (fun fs ->
        Format.pp_print_as fs 0 "\027[0m" ;
        Format.pp_close_box fs () )
      fs fmt


(** [init ?colors ?margin ?config ()] sets up tracing. It sets the margin to [margin] (default
    [240]) and the maximum indentation to [margin - 1] on both the standard formatter and [fs],
    opens a vertical box on [fs], registers [flush] to run at exit, and installs [config]
    (default [none]) with its [colors] field replaced by [colors] (default [false]) as the
    current configuration. *)
let init ?(colors = false) ?(margin = 240) ?config:(c = none) () =
  let max_indent = margin - 1 in
  (* the standard formatter *)
  Format.set_margin margin ;
  Format.set_max_indent max_indent ;
  (* the trace formatter *)
  Format.pp_set_margin fs margin ;
  Format.pp_set_max_indent fs max_indent ;
  Format.pp_open_vbox fs 0 ;
  at_exit flush ;
  config := {c with colors= colors}


(** [split_mod_fun_name s] splits a qualified name such as
    [Dune__exe__Module.Submodule.function] into a module name and a function name, here
    [(Module, function)].

    Trailing [.(fun)] suffixes are dropped first. The function name then starts just after the
    rightmost ['.'] that is not followed by an uppercase letter, or at the start of the name if
    there is no such ['.'], and extends to the end. The module name is the text before the first
    ['.'] (the whole name if there is none), without any prefix that ends in a ["__"] separator.

    NOTE(review): this comment used to give [(Module, function.subfunction)] as the result for
    [Dune__exe__Module.Submodule.Subsubmodule.function.subfunction], but the code selects only
    the last component, [subfunction] — confirm which is intended. *)
let split_mod_fun_name s =
  let rec chop_anon s =
    match String.chop_suffix s ~suf:".(fun)" with Some s -> chop_anon s | None -> s
  in
  let s = chop_anon s in
  let fun_name_end = String.length s in
  let rec fun_name_start_ i =
    match String.rindex_from_opt s i '.' with
    | Some j ->
        (* [rindex_from_opt] examines position [i] itself, so the search for an earlier dot must
           resume strictly to the left of [j]: resuming at [j] finds the same dot again and never
           terminates. The bounds check covers a name ending in ['.'], which has no character
           after the dot to inspect. *)
        if j + 1 < fun_name_end && Char.is_uppercase s.[j + 1] then fun_name_start_ (j - 1)
        else j + 1
    | None ->
        0
  in
  let fun_name_start = fun_name_start_ (fun_name_end - 1) in
  let fun_name = String.sub s ~pos:fun_name_start ~len:(fun_name_end - fun_name_start) in
  let mod_name_end =
    match String.index_from_opt s 0 '.' with Some i -> i | None -> fun_name_end
  in
  (* scan leftwards for two consecutive underscores; the module name starts just after them *)
  let rec mod_name_start_ s i =
    if i <= 1 then None
    else if not (Char.equal '_' s.[i]) then mod_name_start_ s (i - 1)
    else if not (Char.equal '_' s.[i - 1]) then mod_name_start_ s (i - 2)
    else Some (i + 1)
  in
  let mod_name_start =
    match mod_name_start_ s (mod_name_end - 2) with Some pos -> pos | None -> 0
  in
  let mod_name = String.sub s ~pos:mod_name_start ~len:(mod_name_end - mod_name_start) in
  (mod_name, fun_name)


(** [enabled mod_fun_name] tells whether tracing is enabled for the qualified name
    [mod_fun_name] in the current [config]. The setting of the function is used if there is one,
    otherwise the setting of its module if there is one, otherwise [trace_all]. *)
let enabled mod_fun_name =
  let {trace_all; trace_mods_funs; _} = !config in
  if Map.is_empty trace_mods_funs then trace_all
  else
    let mod_name, fun_name = split_mod_fun_name mod_fun_name in
    match Map.find_opt mod_name trace_mods_funs with
    | None ->
        trace_all
    | Some {trace_mod; trace_funs} -> (
      match Map.find_opt fun_name trace_funs with
      | Some fun_enabled ->
          fun_enabled
      | None ->
          Option.value trace_mod ~default:trace_all )


(** [kprintf mod_fun_name k fmt] prints [fmt] to [fs] and then calls [k] if tracing is enabled
    for [mod_fun_name]; otherwise it consumes the arguments of [fmt] without printing, and [k] is
    not called. *)
let kprintf mod_fun_name k fmt =
  match enabled mod_fun_name with
  | true ->
      Format.kfprintf k fs fmt
  | false ->
      Format.ifprintf fs fmt


(** [fprintf mod_fun_name fs fmt] prints [fmt] to the given formatter [fs] if tracing is enabled
    for [mod_fun_name]; otherwise it consumes the arguments of [fmt] without printing. *)
let fprintf mod_fun_name fs fmt =
  let print = if enabled mod_fun_name then Format.fprintf else Format.ifprintf in
  print fs fmt


(** [printf mod_fun_name fmt] prints [fmt] to the trace formatter [fs] if tracing is enabled for
    [mod_fun_name]; otherwise it consumes the arguments of [fmt] without printing. *)
let printf mod_fun_name fmt =
  if enabled mod_fun_name then Format.fprintf fs fmt else Format.ifprintf fs fmt

(** [info mod_fun_name fmt] prints, if tracing is enabled for [mod_fun_name], a new line made of
    ["| "], the module and function names of [mod_fun_name] and a colon, followed by [fmt], all
    in a box with indentation 4. Otherwise the arguments of [fmt] are consumed without printing. *)
let info mod_fun_name fmt =
  if enabled mod_fun_name then (
    let mod_name, fun_name = split_mod_fun_name mod_fun_name in
    let close_box ppf = Format.fprintf ppf "@]" in
    Format.fprintf fs "@\n@[<4>| %s.%s:" mod_name fun_name ;
    Format.kfprintf close_box fs fmt )
  else Format.ifprintf fs fmt


(** [infok_ mod_fun_name fmt] is like [info], except that only ["| "], the module name of
    [mod_fun_name] and a dot are printed before [fmt]: the function name is omitted, so [fmt]
    directly follows the dot. *)
let infok_ mod_fun_name fmt =
  if enabled mod_fun_name then (
    let mod_name, _ = split_mod_fun_name mod_fun_name in
    let close_box ppf = Format.fprintf ppf "@]" in
    Format.fprintf fs "@\n@[<4>| %s." mod_name ;
    Format.kfprintf close_box fs fmt )
  else Format.ifprintf fs fmt


(** [infok mod_fun_name k] calls [k] with a printer that prints through [infok_ mod_fun_name]. *)
let infok mod_fun_name k =
  let printer = {pf= (fun fmt -> infok_ mod_fun_name fmt)} in
  k printer

(** [incf mod_fun_name fmt] prints, if tracing is enabled for [mod_fun_name], a new line made of
    ["( "], the module and function names of [mod_fun_name] and a colon, followed by [fmt]. Two
    boxes are opened before the header and only the inner one is closed after [fmt], so a box
    with indentation 2 is left open on [fs]. Otherwise the arguments of [fmt] are consumed
    without printing. *)
let incf mod_fun_name fmt =
  if enabled mod_fun_name then (
    let mod_name, fun_name = split_mod_fun_name mod_fun_name in
    let close_inner_box ppf = Format.fprintf ppf "@]" in
    Format.fprintf fs "@\n@[<2>@[<hv 2>( %s.%s:" mod_name fun_name ;
    Format.kfprintf close_inner_box fs fmt )
  else Format.ifprintf fs fmt


(** [decf mod_fun_name fmt] prints, if tracing is enabled for [mod_fun_name], the closing
    counterpart of [incf]: it first closes one box (presumably the one a matching [incf] left
    open), then prints a new line made of [") "], the module and function names of
    [mod_fun_name] and a colon, followed by [fmt], in a box with indentation 2. Otherwise the
    arguments of [fmt] are consumed without printing. *)
let decf mod_fun_name fmt =
  if enabled mod_fun_name then (
    let mod_name, fun_name = split_mod_fun_name mod_fun_name in
    let close_box ppf = Format.fprintf ppf "@]" in
    Format.fprintf fs "@]@\n@[<2>) %s.%s:@ " mod_name fun_name ;
    Format.kfprintf close_box fs fmt )
  else Format.ifprintf fs fmt


(** [call mod_fun_name k] calls [k] with a printer that prints through [incf mod_fun_name]. *)
let call mod_fun_name k =
  let printer = {pf= (fun fmt -> incf mod_fun_name fmt)} in
  k printer

(** [retn mod_fun_name k result] calls [k] with a printer that prints through
    [decf mod_fun_name] and with [result], then returns [result] unchanged. *)
let retn mod_fun_name k result =
  let printer = {pf= (fun fmt -> decf mod_fun_name fmt)} in
  k printer result ;
  result


(** [dbgs thunk force ?call ?retn ?rais mod_fun_name m] traces a state-passing computation.

    [thunk] receives a function from an initial state [s] that: runs [call] with a printer based
    on [incf]; evaluates [force m s]; on normal return with [(result, s')], runs [retn] with a
    printer based on [decf] and [(result, (s, s'))], and returns [(result, s')]; on an exception,
    runs [rais] with a [decf]-based printer, [s], the exception and its backtrace, and re-raises
    the exception with that backtrace.

    By default [call] and [retn] print the empty format, and [rais] prints the exception using
    [Printexc.to_string]. *)
let dbgs :
       (('s -> 'r * 's) -> 'n)
    -> ('m -> 's -> 'r * 's)
    -> ?call:(pf -> unit)
    -> ?retn:(pf -> 'r * ('s * 's) -> unit)
    -> ?rais:(pf -> 's -> exn -> Printexc.raw_backtrace -> unit)
    -> string
    -> 'm
    -> 'n =
 fun thunk force ?call ?retn ?rais mod_fun_name m ->
  let enter = {pf= (fun fmt -> incf mod_fun_name fmt)} in
  let leave = {pf= (fun fmt -> decf mod_fun_name fmt)} in
  let on_call = match call with Some f -> f | None -> fun {pf} -> pf "" in
  let on_retn = match retn with Some f -> f | None -> fun {pf} _ -> pf "" in
  let on_rais =
    match rais with
    | Some f ->
        f
    | None ->
        fun {pf} _ exc _ -> pf "%s" (Printexc.to_string exc)
  in
  let run s =
    on_call enter ;
    match force m s with
    | exception exc ->
        (* capture the backtrace before anything else can overwrite it *)
        let bt = Printexc.get_raw_backtrace () in
        on_rais leave s exc bt ;
        Printexc.raise_with_backtrace exc bt
    | result, s' ->
        on_retn leave (result, (s, s')) ;
        (result, s')
  in
  thunk run


(** [dbg ?call ?retn ?rais mod_fun_name k] traces the evaluation of [k ()] and returns its
    result. It is [dbgs] specialized to a computation with a trivial [unit] state: [retn]
    receives only the result, and [rais] only the exception and its backtrace. *)
let dbg :
       ?call:(pf -> unit)
    -> ?retn:(pf -> 'a -> unit)
    -> ?rais:(pf -> exn -> Printexc.raw_backtrace -> unit)
    -> string
    -> (unit -> 'a)
    -> 'a =
 fun ?call ?retn ?rais mod_fun_name k ->
  (* adapt the stateless callbacks to the state-passing signatures expected by [dbgs] *)
  let on_call = match call with Some f -> Some (fun pf -> f pf) | None -> None in
  let on_retn = match retn with Some f -> Some (fun pf (r, _) -> f pf r) | None -> None in
  let on_rais = match rais with Some f -> Some (fun pf _ -> f pf) | None -> None in
  let thunk run = fst (run ()) in
  let force k () = (k (), ()) in
  dbgs thunk force ?call:on_call ?retn:on_retn ?rais:on_rais mod_fun_name k


(** [raisef ?margin exn fmt] formats a message according to [fmt] and the arguments that follow,
    into a string, inside a box with indentation 2. The result is a function that, applied to a
    final [()], raises the exception [exn msg] for the formatted message [msg].

    The message is built with [Format.str_formatter]; if [margin] is given, the margin of that
    formatter is set to [margin] and its maximum indentation to [margin - 1]. *)
let raisef ?margin exn fmt =
  let str_fs = Format.str_formatter in
  Option.iter
    (fun m ->
      Format.pp_set_margin str_fs m ;
      Format.pp_set_max_indent str_fs (m - 1) )
    margin ;
  Format.pp_open_box str_fs 2 ;
  let finish ppf () =
    Format.pp_close_box ppf () ;
    let msg = Format.flush_str_formatter () in
    raise (exn msg)
  in
  Format.kfprintf finish str_fs fmt


(** [fail fmt] formats a message according to [fmt] and the arguments that follow, using the
    margin of the trace formatter [fs]. Once applied to a final [()], it prints the message on a
    new line of [fs], flushes [fs], and raises [Failure] with the message. *)
let fail fmt =
  let report_and_wrap msg =
    Format.fprintf fs "@\n%s@." msg ;
    Failure msg
  in
  raisef ~margin:(Format.pp_get_margin fs ()) report_and_wrap fmt
